/*
 * Copyright 2008 Web Cohesion
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.security.oauth.consumer.client;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.HttpURLConnection;
import java.net.MalformedURLException;
import java.net.ProtocolException;
import java.net.Proxy;
import java.net.URL;
import java.net.URLConnection;
import java.net.URLEncoder;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;

import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner;
import org.springframework.security.oauth.common.OAuthConsumerParameter;
import org.springframework.security.oauth.common.signature.HMAC_SHA1SignatureMethod;
import org.springframework.security.oauth.common.signature.OAuthSignatureMethod;
import org.springframework.security.oauth.common.signature.OAuthSignatureMethodFactory;
import org.springframework.security.oauth.common.signature.SharedConsumerSecret;
import org.springframework.security.oauth.common.signature.SharedConsumerSecretImpl;
import org.springframework.security.oauth.consumer.InvalidOAuthRealmException;
import org.springframework.security.oauth.consumer.OAuthConsumerToken;
import org.springframework.security.oauth.consumer.OAuthRequestFailedException;
import org.springframework.security.oauth.consumer.ProtectedResourceDetails;
import org.springframework.security.oauth.consumer.net.DefaultOAuthURLStreamHandlerFactory;

import sun.net.www.protocol.http.Handler;

/**
 * @author Ryan Heaton
 */
@SuppressWarnings("restriction")
@RunWith(MockitoJUnitRunner.class)
public class CoreOAuthConsumerSupportTests {
	@Mock
	private ProtectedResourceDetails details;

	/**
	 * afterPropertiesSet
	 */
	@Test
	public void testAfterPropertiesSet() throws Exception {
		try {
			new CoreOAuthConsumerSupport().afterPropertiesSet();
			fail("should required a protected resource details service.");
		}
		catch (IllegalArgumentException e) {
		}
	}

	/**
	 * readResouce
	 */
	@Test
	public void testReadResouce() throws Exception {

		OAuthConsumerToken token = new OAuthConsumerToken();
		URL url = new URL("https://myhost.com/resource?with=some&query=params&too");
		final ConnectionProps connectionProps = new ConnectionProps();
		final ByteArrayInputStream inputStream = new ByteArrayInputStream(new byte[0]);

		final HttpURLConnectionForTestingPurposes connectionMock = new HttpURLConnectionForTestingPurposes(url) {
			@Override
			public void setRequestMethod(String method) throws ProtocolException {
				connectionProps.method = method;
			}

			@Override
			public void setDoOutput(boolean dooutput) {
				connectionProps.doOutput = dooutput;
			}

			@Override
			public void connect() throws IOException {
				connectionProps.connected = true;
			}

			@Override
			public OutputStream getOutputStream() throws IOException {
				ByteArrayOutputStream out = new ByteArrayOutputStream();
				connectionProps.outputStream = out;
				return out;
			}

			@Override
			public int getResponseCode() throws IOException {
				return connectionProps.responseCode;
			}

			@Override
			public String getResponseMessage() throws IOException {
				return connectionProps.responseMessage;
			}

			@Override
			public InputStream getInputStream() throws IOException {
				return inputStream;
			}

			@Override
			public String getHeaderField(String name) {
				return connectionProps.headerFields.get(name);
			}
		};

		CoreOAuthConsumerSupport support = new CoreOAuthConsumerSupport() {
			@Override
			public URL configureURLForProtectedAccess(URL url, OAuthConsumerToken accessToken,
					ProtectedResourceDetails details, String httpMethod, Map<String, String> additionalParameters)
					throws OAuthRequestFailedException {
				try {
					return new URL(url.getProtocol(), url.getHost(), url.getPort(), url.getFile(),
							new StreamHandlerForTestingPurposes(connectionMock));
				}
				catch (MalformedURLException e) {
					throw new RuntimeException(e);
				}
			}

			@Override
			public String getOAuthQueryString(ProtectedResourceDetails details, OAuthConsumerToken accessToken,
					URL url, String httpMethod, Map<String, String> additionalParameters) {
				return "POSTBODY";
			}
		};
		support.setStreamHandlerFactory(new DefaultOAuthURLStreamHandlerFactory());

		when(details.getAuthorizationHeaderRealm()).thenReturn("realm1");
		when(details.isAcceptsAuthorizationHeader()).thenReturn(true);
		when(details.getAdditionalRequestHeaders()).thenReturn(null);
		try {
			support.readResource(details, url, "POST", token, null, null);
			fail("shouldn't have been a valid response code.");
		}
		catch (OAuthRequestFailedException e) {
			// fall through...
		}
		assertFalse(connectionProps.doOutput);
		assertEquals("POST", connectionProps.method);
		assertTrue(connectionProps.connected);
		connectionProps.reset();

		when(details.getAuthorizationHeaderRealm()).thenReturn(null);
		when(details.isAcceptsAuthorizationHeader()).thenReturn(true);
		when(details.getAdditionalRequestHeaders()).thenReturn(null);
		connectionProps.responseCode = 400;
		connectionProps.responseMessage = "Nasty";
		try {
			support.readResource(details, url, "POST", token, null, null);
			fail("shouldn't have been a valid response code.");
		}
		catch (OAuthRequestFailedException e) {
			// fall through...
		}
		assertFalse(connectionProps.doOutput);
		assertEquals("POST", connectionProps.method);
		assertTrue(connectionProps.connected);
		connectionProps.reset();

		when(details.getAuthorizationHeaderRealm()).thenReturn(null);
		when(details.isAcceptsAuthorizationHeader()).thenReturn(true);
		when(details.getAdditionalRequestHeaders()).thenReturn(null);
		connectionProps.responseCode = 401;
		connectionProps.responseMessage = "Bad Realm";
		connectionProps.headerFields.put("WWW-Authenticate", "realm=\"goodrealm\"");
		try {
			support.readResource(details, url, "POST", token, null, null);
			fail("shouldn't have been a valid response code.");
		}
		catch (InvalidOAuthRealmException e) {
			// fall through...
		}
		assertFalse(connectionProps.doOutput);
		assertEquals("POST", connectionProps.method);
		assertTrue(connectionProps.connected);
		connectionProps.reset();

		when(details.getAuthorizationHeaderRealm()).thenReturn(null);
		when(details.isAcceptsAuthorizationHeader()).thenReturn(true);
		when(details.getAdditionalRequestHeaders()).thenReturn(null);
		connectionProps.responseCode = 200;
		connectionProps.responseMessage = "Congrats";
		assertSame(inputStream, support.readResource(details, url, "GET", token, null, null));
		assertFalse(connectionProps.doOutput);
		assertEquals("GET", connectionProps.method);
		assertTrue(connectionProps.connected);
		connectionProps.reset();

		when(details.getAuthorizationHeaderRealm()).thenReturn(null);
		when(details.isAcceptsAuthorizationHeader()).thenReturn(false);
		when(details.getAdditionalRequestHeaders()).thenReturn(null);
		connectionProps.responseCode = 200;
		connectionProps.responseMessage = "Congrats";
		assertSame(inputStream, support.readResource(details, url, "POST", token, null, null));
		assertEquals("POSTBODY", new String(((ByteArrayOutputStream) connectionProps.outputStream).toByteArray()));
		assertTrue(connectionProps.doOutput);
		assertEquals("POST", connectionProps.method);
		assertTrue(connectionProps.connected);
		connectionProps.reset();

	}

	/**
	 * configureURLForProtectedAccess
	 */
	@Test
	public void testConfigureURLForProtectedAccess() throws Exception {
		CoreOAuthConsumerSupport support = new CoreOAuthConsumerSupport() {
			// Inherited.
			@Override
			public String getOAuthQueryString(ProtectedResourceDetails details, OAuthConsumerToken accessToken,
					URL url, String httpMethod, Map<String, String> additionalParameters) {
				return "myquerystring";
			}
		};
		support.setStreamHandlerFactory(new DefaultOAuthURLStreamHandlerFactory());
		OAuthConsumerToken token = new OAuthConsumerToken();
		URL url = new URL("https://myhost.com/somepath?with=some&query=params&too");

		when(details.isAcceptsAuthorizationHeader()).thenReturn(true);
		assertEquals("https://myhost.com/somepath?with=some&query=params&too",
				support.configureURLForProtectedAccess(url, token, details, "GET", null).toString());

		when(details.isAcceptsAuthorizationHeader()).thenReturn(false);
		assertEquals("https://myhost.com/somepath?myquerystring",
				support.configureURLForProtectedAccess(url, token, details, "GET", null).toString());

		assertEquals("https://myhost.com/somepath?with=some&query=params&too",
				support.configureURLForProtectedAccess(url, token, details, "POST", null).toString());

		assertEquals("https://myhost.com/somepath?with=some&query=params&too",
				support.configureURLForProtectedAccess(url, token, details, "PUT", null).toString());
	}

	/**
	 * test getAuthorizationHeader
	 */
	@Test
	public void testGetAuthorizationHeader() throws Exception {
		final TreeMap<String, Set<CharSequence>> params = new TreeMap<String, Set<CharSequence>>();
		CoreOAuthConsumerSupport support = new CoreOAuthConsumerSupport() {
			@Override
			protected Map<String, Set<CharSequence>> loadOAuthParameters(ProtectedResourceDetails details,
					URL requestURL, OAuthConsumerToken requestToken, String httpMethod,
					Map<String, String> additionalParameters) {
				return params;
			}
		};
		URL url = new URL("https://myhost.com/somepath?with=some&query=params&too");
		OAuthConsumerToken token = new OAuthConsumerToken();

		when(details.isAcceptsAuthorizationHeader()).thenReturn(false);
		assertNull(support.getAuthorizationHeader(details, token, url, "POST", null));

		params.put("with", Collections.singleton((CharSequence) "some"));
		params.put("query", Collections.singleton((CharSequence) "params"));
		params.put("too", null);
		when(details.isAcceptsAuthorizationHeader()).thenReturn(true);
		when(details.getAuthorizationHeaderRealm()).thenReturn("myrealm");
		assertEquals("OAuth realm=\"myrealm\", query=\"params\", with=\"some\"",
				support.getAuthorizationHeader(details, token, url, "POST", null));

		params.put(OAuthConsumerParameter.oauth_consumer_key.toString(), Collections.singleton((CharSequence) "mykey"));
		params.put(OAuthConsumerParameter.oauth_nonce.toString(), Collections.singleton((CharSequence) "mynonce"));
		params.put(OAuthConsumerParameter.oauth_timestamp.toString(), Collections.singleton((CharSequence) "myts"));
		when(details.isAcceptsAuthorizationHeader()).thenReturn(true);
		when(details.getAuthorizationHeaderRealm()).thenReturn("myrealm");
		assertEquals(
				"OAuth realm=\"myrealm\", oauth_consumer_key=\"mykey\", oauth_nonce=\"mynonce\", oauth_timestamp=\"myts\", query=\"params\", with=\"some\"",
				support.getAuthorizationHeader(details, token, url, "POST", null));
	}

	/**
	 * getOAuthQueryString
	 */
	@Test
	public void testGetOAuthQueryString() throws Exception {
		final TreeMap<String, Set<CharSequence>> params = new TreeMap<String, Set<CharSequence>>();
		CoreOAuthConsumerSupport support = new CoreOAuthConsumerSupport() {
			@Override
			protected Map<String, Set<CharSequence>> loadOAuthParameters(ProtectedResourceDetails details,
					URL requestURL, OAuthConsumerToken requestToken, String httpMethod,
					Map<String, String> additionalParameters) {
				return params;
			}
		};

		URL url = new URL("https://myhost.com/somepath?with=some&query=params&too");
		OAuthConsumerToken token = new OAuthConsumerToken();

		when(details.isAcceptsAuthorizationHeader()).thenReturn(true);
		params.put("with", Collections.singleton((CharSequence) "some"));
		params.put("query", Collections.singleton((CharSequence) "params"));
		params.put("too", null);
		params.put(OAuthConsumerParameter.oauth_consumer_key.toString(), Collections.singleton((CharSequence) "mykey"));
		params.put(OAuthConsumerParameter.oauth_nonce.toString(), Collections.singleton((CharSequence) "mynonce"));
		params.put(OAuthConsumerParameter.oauth_timestamp.toString(), Collections.singleton((CharSequence) "myts"));
		assertEquals("query=params&too&with=some", support.getOAuthQueryString(details, token, url, "POST", null));

		when(details.isAcceptsAuthorizationHeader()).thenReturn(false);
		params.put("with", Collections.singleton((CharSequence) "some"));
		params.put("query", Collections.singleton((CharSequence) "params"));
		params.put("too", null);
		params.put(OAuthConsumerParameter.oauth_consumer_key.toString(), Collections.singleton((CharSequence) "mykey"));
		params.put(OAuthConsumerParameter.oauth_nonce.toString(), Collections.singleton((CharSequence) "mynonce"));
		params.put(OAuthConsumerParameter.oauth_timestamp.toString(), Collections.singleton((CharSequence) "myts"));
		assertEquals("oauth_consumer_key=mykey&oauth_nonce=mynonce&oauth_timestamp=myts&query=params&too&with=some",
				support.getOAuthQueryString(details, token, url, "POST", null));

		when(details.isAcceptsAuthorizationHeader()).thenReturn(false);
		params.put("with", Collections.singleton((CharSequence) "some"));
		String encoded_space = URLEncoder.encode(" ", "utf-8");
		params.put("query", Collections.singleton((CharSequence) ("params spaced")));
		params.put("too", null);
		params.put(OAuthConsumerParameter.oauth_consumer_key.toString(), Collections.singleton((CharSequence) "mykey"));
		params.put(OAuthConsumerParameter.oauth_nonce.toString(), Collections.singleton((CharSequence) "mynonce"));
		params.put(OAuthConsumerParameter.oauth_timestamp.toString(), Collections.singleton((CharSequence) "myts"));
		assertEquals("oauth_consumer_key=mykey&oauth_nonce=mynonce&oauth_timestamp=myts&query=params" + encoded_space
				+ "spaced&too&with=some", support.getOAuthQueryString(details, token, url, "POST", null));
	}

	/**
	 * getTokenFromProvider
	 */
	@Test
	public void testGetTokenFromProvider() throws Exception {
		final ByteArrayInputStream in = new ByteArrayInputStream(
				"oauth_token=mytoken&oauth_token_secret=mytokensecret".getBytes("UTF-8"));
		CoreOAuthConsumerSupport support = new CoreOAuthConsumerSupport() {
			@Override
			protected InputStream readResource(ProtectedResourceDetails details, URL url, String httpMethod,
					OAuthConsumerToken token, Map<String, String> additionalParameters,
					Map<String, String> additionalRequestHeaders) {
				return in;
			}
		};

		URL url = new URL("https://myhost.com/somepath?with=some&query=params&too");

		when(details.getId()).thenReturn("resourceId");
		OAuthConsumerToken token = support.getTokenFromProvider(details, url, "POST", null, null);
		assertFalse(token.isAccessToken());
		assertEquals("mytoken", token.getValue());
		assertEquals("mytokensecret", token.getSecret());
		assertEquals("resourceId", token.getResourceId());

	}

	/**
	 * loadOAuthParameters
	 */
	@Test
	public void testLoadOAuthParameters() throws Exception {
		URL url = new URL("https://myhost.com/somepath?with=some&query=params&too");
		CoreOAuthConsumerSupport support = new CoreOAuthConsumerSupport() {
			@Override
			protected String getSignatureBaseString(Map<String, Set<CharSequence>> oauthParams, URL requestURL,
					String httpMethod) {
				return "MYSIGBASESTRING";
			}
		};
		OAuthSignatureMethodFactory sigFactory = mock(OAuthSignatureMethodFactory.class);
		support.setSignatureFactory(sigFactory);
		OAuthConsumerToken token = new OAuthConsumerToken();
		OAuthSignatureMethod sigMethod = mock(OAuthSignatureMethod.class);

		when(details.getConsumerKey()).thenReturn("my-consumer-key");
		when(details.getSignatureMethod()).thenReturn(HMAC_SHA1SignatureMethod.SIGNATURE_NAME);
		when(details.getSignatureMethod()).thenReturn(HMAC_SHA1SignatureMethod.SIGNATURE_NAME);
		SharedConsumerSecret secret = new SharedConsumerSecretImpl("shh!!!");
		when(details.getSharedSecret()).thenReturn(secret);
		when(sigFactory.getSignatureMethod(HMAC_SHA1SignatureMethod.SIGNATURE_NAME, secret, null))
				.thenReturn(sigMethod);
		when(sigMethod.sign("MYSIGBASESTRING")).thenReturn("MYSIGNATURE");

		Map<String, Set<CharSequence>> params = support.loadOAuthParameters(details, url, token, "POST", null);
		assertEquals("some", params.remove("with").iterator().next().toString());
		assertEquals("params", params.remove("query").iterator().next().toString());
		assertTrue(params.containsKey("too"));
		assertTrue(params.remove("too").isEmpty());
		assertNull(params.remove(OAuthConsumerParameter.oauth_token.toString()));
		assertNotNull(params.remove(OAuthConsumerParameter.oauth_nonce.toString()).iterator().next());
		assertEquals("my-consumer-key", params.remove(OAuthConsumerParameter.oauth_consumer_key.toString()).iterator()
				.next());
		assertEquals("MYSIGNATURE", params.remove(OAuthConsumerParameter.oauth_signature.toString()).iterator().next());
		assertEquals("1.0", params.remove(OAuthConsumerParameter.oauth_version.toString()).iterator().next());
		assertEquals(HMAC_SHA1SignatureMethod.SIGNATURE_NAME,
				params.remove(OAuthConsumerParameter.oauth_signature_method.toString()).iterator().next());
		assertTrue(Long.parseLong(params.remove(OAuthConsumerParameter.oauth_timestamp.toString()).iterator().next()
				.toString()) <= (System.currentTimeMillis() / 1000));
		assertTrue(params.isEmpty());
	}

	/**
	 * tests getting the signature base string.
	 */
	@Test
	public void testGetSignatureBaseString() throws Exception {
		Map<String, Set<CharSequence>> oauthParams = new HashMap<String, Set<CharSequence>>();
		oauthParams.put("oauth_consumer_key", Collections.singleton((CharSequence) "dpf43f3p2l4k3l03"));
		oauthParams.put("oauth_token", Collections.singleton((CharSequence) "nnch734d00sl2jdk"));
		oauthParams.put("oauth_signature_method", Collections.singleton((CharSequence) "HMAC-SHA1"));
		oauthParams.put("oauth_timestamp", Collections.singleton((CharSequence) "1191242096"));
		oauthParams.put("oauth_nonce", Collections.singleton((CharSequence) "kllo9940pd9333jh"));
		oauthParams.put("oauth_version", Collections.singleton((CharSequence) "1.0"));
		oauthParams.put("file", Collections.singleton((CharSequence) "vacation.jpg"));
		oauthParams.put("size", Collections.singleton((CharSequence) "original"));

		CoreOAuthConsumerSupport support = new CoreOAuthConsumerSupport();

		String baseString = support.getSignatureBaseString(oauthParams, new URL("https://photos.example.net/photos"),
				"geT");
		assertEquals(
				"GET&https%3A%2F%2Fphotos.example.net%2Fphotos&file%3Dvacation.jpg%26oauth_consumer_key%3Ddpf43f3p2l4k3l03%26oauth_nonce%3Dkllo9940pd9333jh%26oauth_signature_method%3DHMAC-SHA1%26oauth_timestamp%3D1191242096%26oauth_token%3Dnnch734d00sl2jdk%26oauth_version%3D1.0%26size%3Doriginal",
				baseString);
	}

	@Test
	public void testGetSignatureBaseStringSimple() throws Exception {
		Map<String, Set<CharSequence>> oauthParams = new HashMap<String, Set<CharSequence>>();
		oauthParams.put("foo", Collections.singleton((CharSequence) "bar"));
		oauthParams.put("bar", new LinkedHashSet<CharSequence>(Arrays.<CharSequence> asList("120", "24")));

		CoreOAuthConsumerSupport support = new CoreOAuthConsumerSupport();

		String baseString = support.getSignatureBaseString(oauthParams, new URL("https://photos.example.net/photos"),
				"get");
		assertEquals("GET&https%3A%2F%2Fphotos.example.net%2Fphotos&bar%3D120%26bar%3D24%26foo%3Dbar", baseString);
	}

	// See SECOAUTH-383
	@Test
	public void testGetSignatureBaseStringMultivaluedLast() throws Exception {
		Map<String, Set<CharSequence>> oauthParams = new HashMap<String, Set<CharSequence>>();
		oauthParams.put("foo", Collections.singleton((CharSequence) "bar"));
		oauthParams.put("pin", new LinkedHashSet<CharSequence>(Arrays.<CharSequence> asList("2", "1")));

		CoreOAuthConsumerSupport support = new CoreOAuthConsumerSupport();

		String baseString = support.getSignatureBaseString(oauthParams, new URL("https://photos.example.net/photos"),
				"get");
		assertEquals("GET&https%3A%2F%2Fphotos.example.net%2Fphotos&foo%3Dbar%26pin%3D1%26pin%3D2", baseString);
	}

	static class StreamHandlerForTestingPurposes extends Handler {

		private final HttpURLConnectionForTestingPurposes connection;

		public StreamHandlerForTestingPurposes(HttpURLConnectionForTestingPurposes connection) {
			this.connection = connection;
		}

		@Override
		protected URLConnection openConnection(URL url) throws IOException {
			return connection;
		}

		@Override
		protected URLConnection openConnection(URL url, Proxy proxy) throws IOException {
			return connection;
		}
	}

	static class HttpURLConnectionForTestingPurposes extends HttpURLConnection {

		/**
		 * Constructor for the HttpURLConnection.
		 * 
		 * @param u the URL
		 */
		public HttpURLConnectionForTestingPurposes(URL u) {
			super(u);
		}

		public void disconnect() {
		}

		public boolean usingProxy() {
			return false;
		}

		public void connect() throws IOException {
		}
	}

	static class ConnectionProps {

		public int responseCode;

		public String responseMessage;

		public String method;

		public Boolean doOutput;

		public Boolean connected;

		public OutputStream outputStream;

		public final Map<String, String> headerFields = new TreeMap<String, String>();

		public void reset() {
			this.responseCode = 0;
			this.responseMessage = null;
			this.method = null;
			this.doOutput = null;
			this.connected = null;
			this.outputStream = null;
			this.headerFields.clear();
		}

	}
}
